Graph vs. Discrete Geodesic Algorithm on Moons

In 13-graph-geodesics.ipynb, we saw that the graph approach gives promising results on the cone dataset. In this notebook, we apply the same approach to the moons dataset and compare it directly with the discrete geodesic algorithm from 08-discrete-geodesics.ipynb. We will see that the graph approach works especially well on datasets where the Euclidean interpolation leads to bad local minima. This time we also use the stochastic Riemannian metric and see that the graph approach still works well.

Imports and plotting library setup

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from copy import deepcopy
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# Set up plotly
init_notebook_mode(connected=True)
layout = go.Layout(
    width=700,
    height=500,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    showlegend=False
)
config={'showLink': False}
colorscale=[[0.0, '#3595E3'], [1.0, '#2D4366']]

# Make results completely deterministic
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)

Make the Moons

We use 10,000 points for training and plot 1,000 of these points in the figures.

In [2]:
from sklearn.datasets import make_moons
from sklearn.preprocessing import MinMaxScaler

x, y = make_moons(n_samples=10000, noise=0.075, random_state=seed)
min_max_scaler = MinMaxScaler((-1, 1))
x = min_max_scaler.fit_transform(x)
x_plot, y_plot = x[:1000, :], y[:1000]

scatter_plot = go.Scatter(
    x = x_plot[:, 0],
    y = x_plot[:, 1],
    mode = 'markers',
    marker = {'color': y_plot, 'colorscale': colorscale}
)
data = [scatter_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Create the VAE

For our encoder and decoder, we use simple neural networks with 2 hidden layers of 10 neurons each.

In [3]:
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras.layers import Dense, Input, Lambda, Flatten
from src.vae import VAE
from src.rbf import RBFLayer

input_dim = 2
latent_dim = 2

# Create the encoder models
enc_input = Input((input_dim,))
enc_shared = Sequential([
    Dense(10, activation='softplus'),
    Dense(10, activation='softplus')
])
enc_mean = Model(enc_input, Dense(latent_dim, activation='linear')(
    enc_shared(enc_input)))
enc_var = Model(enc_input, Dense(latent_dim, activation='softplus')(
    enc_shared(enc_input)))

# Create the decoder models
dec_input = Input((latent_dim,))
dec_mean = Sequential([
    Dense(10, activation='softplus'),
    Dense(10, activation='softplus'),
    Dense(input_dim, activation='linear')
])
dec_mean = Model(dec_input, dec_mean(dec_input))

# Build the RBF network
num_centers = 20
a = 0.75
rbf = RBFLayer([input_dim], num_centers)
dec_var = Model(dec_input, rbf(dec_input))

vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=0.1)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.

Train the VAE

In [4]:
history = vae.model.fit(x,
                        shuffle=True,
                        epochs=20,
                        batch_size=100,
                        verbose=0)
# Plot the losses
data = [go.Scatter(y=history.history['loss'])]
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
iplot(go.Figure(data=data, layout=plot_layout), config=config)

Visualize the latent space

In [5]:
# Display a 2D plot of the classes in the latent space
_, encoded_mean, _ = vae.encoder.predict(x_plot)
latent_scatter_plot = go.Scatter(
    x = encoded_mean[:, 0],
    y = encoded_mean[:, 1],
    mode = 'markers',
    marker = {'color': y_plot, 'colorscale': colorscale},
    hoverinfo = 'text',
    text = np.arange(len(encoded_mean))
)
data = [latent_scatter_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Train the generator's variance network

Compute the centers

Use k-means clustering on the latent representations to find the centers of our radial basis functions.

In [6]:
from sklearn.cluster import KMeans

# Find the centers of the latent representations
_, encoded_train, _ = vae.encoder.predict(x)
kmeans = KMeans(n_clusters=num_centers, random_state=0).fit(encoded_train)
centers = kmeans.cluster_centers_

# Visualize the centers
center_plot = go.Scatter(
    x = centers[:, 0],
    y = centers[:, 1],
    mode = 'markers',
    marker = {'color': 'red'}
)
data = [latent_scatter_plot, center_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Compute the bandwidths

This follows equation 11 from the Latent Space Oddity paper.

In [7]:
# Cluster the latent representations
clustering = dict((c_i, []) for c_i in range(num_centers))
for z_i, c_i in zip(encoded_train, kmeans.predict(encoded_train)):
    clustering[c_i].append(z_i)
    
bandwidths = []
for c_i, cluster in clustering.items():
    if cluster:
        diffs = np.array(cluster) - centers[c_i]
        avg_dist = np.mean(np.linalg.norm(diffs, axis=1))
        bandwidth = 0.5 / (a * avg_dist)**2
    else:
        bandwidth = 0
    bandwidths.append(bandwidth)
bandwidths = np.array(bandwidths)
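As a quick sanity check, the bandwidth formula can be exercised in isolation on a hand-made toy cluster (made-up points, not the moons latents):

```python
import numpy as np

# Toy cluster around a single center to exercise the bandwidth
# formula from above (eq. 11 of the Latent Space Oddity paper).
a = 0.75
center = np.array([0.0, 0.0])
cluster = np.array([[0.1, 0.0], [-0.1, 0.0], [0.0, 0.15]])

diffs = cluster - center
avg_dist = np.mean(np.linalg.norm(diffs, axis=1))
bandwidth = 0.5 / (a * avg_dist)**2

print(avg_dist, bandwidth)  # tighter clusters give larger bandwidths
```

A tighter cluster (smaller average distance to the center) yields a larger bandwidth, so the corresponding RBF decays faster away from the data.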

Train the decoder's variance network

using the centers and bandwidths computed above.

In [8]:
# Train the RBF
vae.recompile_for_var_training()
rbf_kernel = rbf.get_weights()[0]
rbf.set_weights([rbf_kernel, centers, bandwidths])

history = vae.model.fit(x,
                        shuffle=True,
                        epochs=200,
                        batch_size=100,
                        verbose=0)

# Plot the losses
data = [go.Scatter(y=history.history['loss'])]
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
iplot(go.Figure(data=data, layout=plot_layout), config=config)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.

Plot the reconstructions

In [9]:
# Feed the encoder mean to the decoder
decoded, decoded_mean, decoded_var = vae.decoder.predict(encoded_mean)
scatter_plot = go.Scatter(
    x = decoded_mean[:, 0],
    y = decoded_mean[:, 1],
    mode = 'markers',
    marker = {'color': y_plot, 'colorscale': colorscale}
)
data = [scatter_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Choose two points for computing the distances

In [10]:
z_start_index = 39
z_end_index = 972
z_start, z_end = encoded_mean[z_start_index], encoded_mean[z_end_index]

task_plot = go.Scatter(
    x = np.array([z_start[0], z_end[0]]),
    y = np.array([z_start[1], z_end[1]]),
    mode = 'markers',
    marker = {'color': 'red'}
)

Plot the magnification factor

together with the two points

In [11]:
# Get the mean and std predictors
from src.util import wrap_model_in_float64

_, mean, var = vae.decoder.output
std = Lambda(tf.sqrt)(var)

dec_mean = Model(vae.decoder.input, Flatten()(mean))
dec_std = Model(vae.decoder.input, Flatten()(std))
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)
In [29]:
from src.plot import plot_magnification_factor

sess = tf.keras.backend.get_session()
heatmap_z1 = np.linspace(-3, 3, 100)
heatmap_z2 = np.linspace(-3, 3, 100)
heatmap = plot_magnification_factor(sess, 
                                    heatmap_z1,
                                    heatmap_z2, 
                                    dec_mean, 
                                    dec_std, 
                                    additional_data=[latent_scatter_plot, task_plot],
                                    layout=layout,
                                    log_scale=True)
Computing Magnification Factors: 100%|██████████| 500/500 [00:01<00:00, 331.92it/s]

Discrete Geodesic

Now that we have our task, we can use the discrete geodesic algorithm from Shao et al. to approximate the geodesic and thereby the Riemannian distance between the two points. The algorithm is implemented and explained in 08-discrete-geodesics.ipynb.

In [30]:
%%time
from src.discrete import find_geodesic_discrete

discrete_curve, iterations = find_geodesic_discrete(sess, z_start, z_end, 
                                                    dec_mean, 
                                                    std_generator=dec_std,
                                                    num_nodes=1000,
                                                    learning_rate=0.001,
                                                    max_steps=500,
                                                    save_every=50,
                                                    log_every=50)
print('-' * 20)
Step 0, Length 55.214904, Energy 4143.372779, Max velocity ratio 269.813764
Step 50, Length 46.671003, Energy 2228.297880, Max velocity ratio 188.728705
Step 100, Length 41.429694, Energy 1652.057688, Max velocity ratio 123.674313
Step 150, Length 37.287758, Energy 1240.310495, Max velocity ratio 138.160361
Step 200, Length 34.731490, Energy 978.876862, Max velocity ratio 129.754557
Step 250, Length 32.568597, Energy 805.946758, Max velocity ratio 149.282711
Step 300, Length 30.869219, Energy 686.790371, Max velocity ratio 159.161152
Step 350, Length 29.403932, Energy 595.035285, Max velocity ratio 211.934692
Step 400, Length 28.192545, Energy 530.755320, Max velocity ratio 251.923834
Step 450, Length 27.174228, Energy 475.508310, Max velocity ratio 232.942501
Step 500, Length 26.257712, Energy 432.948446, Max velocity ratio 250.252698
--------------------
CPU times: user 30.6 s, sys: 1.01 s, total: 31.6 s
Wall time: 14.6 s
In [31]:
from src.plot import plot_latent_curve_iterations

plot_latent_curve_iterations(iterations, 
                             [heatmap, latent_scatter_plot, task_plot],
                             layout,
                             step_size=500)

A fair comparison

We see that the curve length steadily decreases for the discrete geodesic algorithm. However, the plot above also shows that the algorithm "jumps" over regions of large magnification factor to avoid "paying" for crossing them. You can zoom in on the green points in the plot above to see the jumps. The length estimate of the discrete algorithm is therefore flawed. We need a function that removes these jumps and takes equidistant steps in the latent space, so that we can measure the actual Riemannian length of a curve:

In [32]:
def interpolate(curve, num_nodes):
    # Take equidistant steps in the latent space
    latent_lengths = np.linalg.norm(curve[1:] - curve[:-1], axis=1)
    latent_length = np.sum(latent_lengths)
    step_size = latent_length / (num_nodes - 1)
    
    interpolation = [curve[0]]
    i_curve_node = 1
    # Construct the curve step by step
    for i_node in range(num_nodes-2):
        current_step_length = 0
        position = interpolation[-1]
        
        # Find the next interval on the curve
        curve_step_length = np.linalg.norm(curve[i_curve_node] - position)
        while step_size - current_step_length >= curve_step_length:
            # This curve step fits completely into our current step
            current_step_length += curve_step_length
            position = curve[i_curve_node]
            i_curve_node += 1
            curve_step_length = np.linalg.norm(curve[i_curve_node] - position)
            
        # Take the missing partial step
        relative_step = (step_size - current_step_length) / curve_step_length
        next_node = position + relative_step * (curve[i_curve_node] - 
                                                position)
        interpolation.append(next_node)
    interpolation.append(curve[-1])
    return np.array(interpolation)
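For reference, the same equidistant resampling can be expressed more compactly as an arc-length reparameterization with np.interp. This is a sketch, not a drop-in replacement for the loop above, but both place nodes at equal arc-length intervals along the polyline:

```python
import numpy as np

def interpolate_arclength(curve, num_nodes):
    # Cumulative arc length along the polyline
    deltas = np.linalg.norm(np.diff(curve, axis=0), axis=1)
    cum = np.concatenate([[0.0], np.cumsum(deltas)])
    # Equidistant targets along the total length
    targets = np.linspace(0.0, cum[-1], num_nodes)
    # Interpolate each coordinate against the arc length
    return np.column_stack([
        np.interp(targets, cum, curve[:, d]) for d in range(curve.shape[1])
    ])

# A polyline with very uneven node spacing along a right angle
curve = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [1.0, 1.0]])
resampled = interpolate_arclength(curve, 5)
steps = np.linalg.norm(np.diff(resampled, axis=0), axis=1)
print(steps)  # total length 2.0 over 4 steps -> 0.5 each
```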

Below, we plot the corrected discrete curve with a small number of nodes. To measure the actual Riemannian length of a curve precisely, we would use 100+ interpolation points; this plot merely shows that the interpolation function works properly.

In [33]:
interpolate_test = interpolate(discrete_curve, 20)
plot_latent_curve_iterations([interpolate_test], 
                             [heatmap, latent_scatter_plot, task_plot], 
                             layout)

Given our interpolate function that corrects the curve, we can now measure the actual Riemannian length of the discrete algorithm's solution. The algorithm started with a Euclidean interpolation (length 55.2), reached a length of 46.7 after 50 steps, and further decreased the curve length to 26.3 after 500 steps. However, below we see that the actual curve length of the final solution after 500 steps is 48.4. Therefore, most of the steps of the discrete algorithm had little effect on the actual curve length.

In [34]:
from src.util import get_length_op, get_lengths_op

curve_ph = tf.placeholder(tf.float64, [None, 2])
lengths_op = get_lengths_op(curve_ph, dec_mean, dec_std)
lengths_op = tf.squeeze(lengths_op)

def evaluate_curve(curve, num_nodes=200):
    curve = interpolate(curve, num_nodes)
    lengths = sess.run(lengths_op, feed_dict={curve_ph: curve})
    length = np.sum(lengths)
    print('Curve length: ', length)
In [35]:
evaluate_curve(discrete_curve)
Curve length:  48.417872243521145

Graph approach

Now that we know how the discrete algorithm performs (stochastic Riemannian length of 48.4), we can try the graph approach from 13-graph-geodesics.ipynb.

Create a graph in the latent space

In [36]:
graph_points = encoded_mean
extensions = [graph_points + np.random.randn(*graph_points.shape) 
              for _ in range(2)]
graph_points = np.concatenate([graph_points] + extensions)
print(graph_points.shape)
(3000, 2)
In [37]:
graph_plot = go.Scatter(
    x = graph_points[:, 0],
    y = graph_points[:, 1],
    mode = 'markers',
    marker = {'color': '#EE7830', 'size': 4}
)

data = [graph_plot, heatmap]
iplot(go.Figure(data=data, layout=layout), config=config)

Compute the Riemannian distances of neighboring points

To get the nearest neighbors of each point, we use the get_neighbors function from src.graph. It is explained and defined in 13-graph-geodesics.ipynb.

Given the get_neighbors function, we compute the Riemannian distance between each point and each of its neighbors. We approximate the integral with a single midpoint evaluation: $\int_0^1 \left\| J_{\gamma_t} \dot{\gamma}_t \right\| \mathrm{d}t \approx \left\| J_{\gamma_{0.5}} \dot{\gamma}_{0.5} \right\|$

In [38]:
from src.graph import get_neighbors
from src.util import get_metric_op
import networkx as nx
from tqdm import tqdm

point_ph = tf.placeholder(tf.float32, [2])
metric_op = get_metric_op(point_ph, dec_mean, dec_std)

# Add all nodes from the graph points
graph = nx.Graph()
for i_point, point in enumerate(graph_points):
    graph.add_node(i_point, pos=point)
    
k = 4

# Compute the Riemannian distance between the kNNs
for i_point, point in enumerate(tqdm(graph_points)):
    neighbor_indices = get_neighbors(i_point, graph_points, k)
    
    for i_neighbor in neighbor_indices:
        neighbor = graph_points[i_neighbor]
        
        # Compute the Riemannian distance with a single midpoint
        middle = point + 0.5 * (neighbor - point) 
        velocity = neighbor - point
        
        # Get the Riemannian metric at the midpoint
        metric = sess.run(metric_op, feed_dict={point_ph: middle})
        length = velocity.T.dot(metric).dot(velocity)
        length = np.sqrt(length)
        graph.add_edge(i_point, i_neighbor, weight=length) 
100%|██████████| 3000/3000 [00:21<00:00, 140.93it/s]
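To get a feel for the error of the single-midpoint quadrature used above, here is a standalone toy check with a hand-picked diagonal metric (not the VAE's pull-back metric):

```python
import numpy as np

# Toy metric G(z) = diag(1 + z1^2, 1), chosen so the exact segment
# length differs from the midpoint estimate.
def metric(z):
    return np.diag([1.0 + z[0]**2, 1.0])

start, end = np.array([0.0, 0.0]), np.array([1.0, 0.0])
velocity = end - start

# Single-midpoint estimate (what the graph edges use)
mid = start + 0.5 * velocity
est = np.sqrt(velocity @ metric(mid) @ velocity)

# Dense midpoint-rule integration as a reference
ts = np.linspace(0, 1, 1001)[:-1] + 0.0005
ref = np.mean([np.sqrt(velocity @ metric(start + t * velocity) @ velocity)
               for t in ts])

print(est, ref)  # ~1.118 vs ~1.148: a few percent off
```

For short edges between nearby neighbors the metric varies little along the segment, so the single-midpoint error is much smaller than in this deliberately coarse example.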

Visualize the graph

and the relative weight of the edges (Riemannian length divided by Euclidean length). Green means a low relative weight, red means a large relative weight.

In [39]:
from src.plot import plot_graph_with_edge_colors

x_range = [-3., 3.]
y_range = [-3., 3.]

subnodes = []
for node in graph.nodes():
    pos = graph.node[node]['pos']
    if (x_range[0] <= pos[0] <= x_range[1] and
        y_range[0] <= pos[1] <= y_range[1]):
        subnodes.append(node)

subgraph = graph.subgraph(subnodes)
graph_plot = plot_graph_with_edge_colors(subgraph, layout=layout)

Compute the shortest path

between the two points from above.

In [40]:
%%time
from networkx.algorithms.shortest_paths.generic import shortest_path
path = shortest_path(graph, z_start_index, z_end_index, weight='weight')
length = 0
for source, sink in zip(path[:-1], path[1:]):
    length += graph[source][sink]['weight']
print('Path length:', length)
Path length: 4.903976338663146
CPU times: user 24.3 ms, sys: 3.28 ms, total: 27.6 ms
Wall time: 30.3 ms
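For readers unfamiliar with networkx, the same call on a tiny hand-built graph (toy node names and weights) makes the edge-weight semantics concrete:

```python
import networkx as nx

# Toy weighted graph: the direct edge a-c is "expensive" (think of it
# as crossing a high-metric region), the detour a-b-c is cheaper.
g = nx.Graph()
g.add_edge('a', 'c', weight=10.0)
g.add_edge('a', 'b', weight=1.0)
g.add_edge('b', 'c', weight=2.0)

path = nx.shortest_path(g, 'a', 'c', weight='weight')
length = nx.shortest_path_length(g, 'a', 'c', weight='weight')
print(path, length)  # ['a', 'b', 'c'] 3.0
```

With `weight='weight'`, networkx runs Dijkstra's algorithm and minimizes the sum of edge weights, so it prefers the detour over the single expensive edge.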

Visualize the shortest path

In [41]:
from src.plot import plot_graph

# Construct a subgraph from the path
path_graph = nx.Graph()
for point in path:
    path_graph.add_node(point, pos=graph_points[point])
for source, sink in zip(path[:-1], path[1:]):
    weight = graph[source][sink]['weight']
    path_graph.add_edge(source, sink, weight=weight) 

_ = plot_graph(path_graph, layout=layout, edge_width=0, 
               additional_data=graph_plot)

Measure the actual curve length

Since we only computed the Riemannian distance for each edge using a single midpoint, the graph length is not exact. It is not as strongly biased as the discrete geodesic algorithm's length estimate, but for a fair comparison we measure it with the interpolate function as well.

In [42]:
graph_curve = graph_points[path]
evaluate_curve(graph_curve)
Curve length:  5.38019038379054

Conclusion

  • Discrete algorithm solution: 48.4
  • Graph shortest path: 5.4
In [43]:
# Plot the graph curve
graph_curve_plot = go.Scatter(
    x=graph_curve[:, 0],
    y=graph_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#3CA64D'}
)
# Plot the discrete curve
discrete_curve_plot = go.Scatter(
    x=discrete_curve[:, 0],
    y=discrete_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#d32f2f'}
)
data = [heatmap, latent_scatter_plot, graph_curve_plot, task_plot, 
        discrete_curve_plot]
iplot(go.Figure(data=data, layout=layout), config=config)